{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Must be set before TensorFlow is imported, otherwise the native start-up\n",
    "# log messages emitted during the import are not filtered.\n",
    "# '0' = all logs, '1' = hide INFO, '2' = also hide WARNING, '3' = also hide ERROR\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers, optimizers, datasets\n",
    "from tensorflow.keras.layers import Dense, Dropout, Flatten\n",
    "from tensorflow.keras.layers import Conv2D, MaxPooling2D\n",
    "\n",
    "def mnist_dataset(train_size=20000, batch_size=32,\n",
    "                  test_size=20000, test_batch_size=20000):\n",
    "    \"\"\"Build the MNIST training and evaluation input pipelines.\n",
    "\n",
    "    Images are reshaped to (N, 28, 28, 1) and every example is passed through\n",
    "    `prepare_mnist_features_and_labels`.\n",
    "\n",
    "    Args:\n",
    "        train_size: number of leading training examples to keep; also used\n",
    "            as the shuffle buffer size.\n",
    "        batch_size: batch size of the training dataset.\n",
    "        test_size: upper bound on the number of test examples kept; also\n",
    "            used as the shuffle buffer size.\n",
    "        test_batch_size: batch size of the test dataset. The default exceeds\n",
    "            the 10,000-image MNIST test split, so the whole split arrives as\n",
    "            a single batch.\n",
    "\n",
    "    Returns:\n",
    "        A `(train_ds, test_ds)` pair of batched `tf.data.Dataset` objects.\n",
    "    \"\"\"\n",
    "    (x, y), (x_test, y_test) = datasets.mnist.load_data()\n",
    "    # Add the trailing channel axis expected by Conv2D.\n",
    "    x = x.reshape(x.shape[0], 28, 28, 1)\n",
    "    x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)\n",
    "\n",
    "    ds = tf.data.Dataset.from_tensor_slices((x, y))\n",
    "    ds = ds.map(prepare_mnist_features_and_labels)\n",
    "    ds = ds.take(train_size).shuffle(train_size).batch(batch_size)\n",
    "\n",
    "    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
    "    test_ds = test_ds.map(prepare_mnist_features_and_labels)\n",
    "    test_ds = test_ds.take(test_size).shuffle(test_size).batch(test_batch_size)\n",
    "    return ds, test_ds\n",
    "\n",
    "def prepare_mnist_features_and_labels(x, y):\n",
    "    \"\"\"Cast features to float32 divided by 255 and labels to int64.\"\"\"\n",
    "    features = tf.cast(x, tf.float32) / 255.0\n",
    "    labels = tf.cast(y, tf.int64)\n",
    "    return features, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 建立模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "class myConvModel(keras.Model):\n",
    "    \"\"\"Two convolution + max-pool stages followed by a two-layer dense head.\n",
    "\n",
    "    The final layer has 10 units and no activation, so `call` returns logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.l1_conv = Conv2D(32, (5, 5), activation='relu', padding='same')\n",
    "        self.l2_conv = Conv2D(64, (5, 5), activation='relu', padding='same')\n",
    "        # One pooling layer instance is applied after each convolution.\n",
    "        self.pool = MaxPooling2D(pool_size=(2, 2), strides=2)\n",
    "        self.flat = Flatten()\n",
    "        self.dense1 = layers.Dense(100, activation='tanh')\n",
    "        self.dense2 = layers.Dense(10)\n",
    "\n",
    "    @tf.function\n",
    "    def call(self, x):\n",
    "        \"\"\"Run the forward pass on a batch `x` and return the logits.\"\"\"\n",
    "        stage1 = self.pool(self.l1_conv(x))\n",
    "        stage2 = self.pool(self.l2_conv(stage1))\n",
    "        hidden = self.dense1(self.flat(stage2))\n",
    "        return self.dense2(hidden)\n",
    "\n",
    "# Instantiate the network and an Adam optimizer with its default settings;\n",
    "# both are module-level objects used by the training cell below.\n",
    "model = myConvModel()\n",
    "optimizer = optimizers.Adam()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义loss以及train loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def compute_loss(logits, labels):\n",
    "    \"\"\"Return the mean sparse softmax cross-entropy of `logits` vs `labels`.\"\"\"\n",
    "    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "        labels=labels, logits=logits)\n",
    "    return tf.reduce_mean(per_example)\n",
    "\n",
    "@tf.function\n",
    "def compute_accuracy(logits, labels):\n",
    "    \"\"\"Return the fraction of rows whose arg-max logit equals the label.\n",
    "\n",
    "    Args:\n",
    "        logits: tensor of per-class scores; the class axis is axis 1.\n",
    "        labels: integer class ids, one per row of `logits`.\n",
    "\n",
    "    Returns:\n",
    "        Scalar float32 tensor with the mean accuracy.\n",
    "    \"\"\"\n",
    "    predictions = tf.argmax(logits, axis=1)\n",
    "    # tf.argmax yields int64 and tf.equal requires matching dtypes, so align\n",
    "    # the labels (e.g. int32) with the predictions before comparing.\n",
    "    targets = tf.cast(labels, predictions.dtype)\n",
    "    return tf.reduce_mean(tf.cast(tf.equal(predictions, targets), tf.float32))\n",
    "\n",
    "@tf.function\n",
    "def train_one_step(model, optimizer, x, y):\n",
    "    \"\"\"Apply one gradient update for the batch `(x, y)`.\n",
    "\n",
    "    Returns:\n",
    "        `(loss, accuracy)` scalar tensors computed from the logits produced\n",
    "        before the weights were updated.\n",
    "    \"\"\"\n",
    "    with tf.GradientTape() as tape:\n",
    "        logits = model(x)\n",
    "        loss = compute_loss(logits, y)\n",
    "\n",
    "    # Back-propagate and let the optimizer update the weights.\n",
    "    gradients = tape.gradient(loss, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "    return loss, compute_accuracy(logits, y)\n",
    "\n",
    "@tf.function\n",
    "def test_step(model, x, y):\n",
    "    \"\"\"Forward pass on one batch; return its `(loss, accuracy)` tensors.\"\"\"\n",
    "    logits = model(x)\n",
    "    return compute_loss(logits, y), compute_accuracy(logits, y)\n",
    "\n",
    "def train(epoch, model, optimizer, ds):\n",
    "    \"\"\"Train `model` for one pass over `ds`, logging every 500th step.\n",
    "\n",
    "    Returns:\n",
    "        The `(loss, accuracy)` of the last batch processed, or `(0.0, 0.0)`\n",
    "        when `ds` yields no batches.\n",
    "    \"\"\"\n",
    "    loss, accuracy = 0.0, 0.0\n",
    "    for batch_index, (images, labels) in enumerate(ds):\n",
    "        loss, accuracy = train_one_step(model, optimizer, images, labels)\n",
    "        if batch_index % 500 != 0:\n",
    "            continue\n",
    "        print('epoch', epoch, ': loss', loss.numpy(), '; accuracy', accuracy.numpy())\n",
    "\n",
    "    return loss, accuracy\n",
    "\n",
    "def test(model, ds):\n",
    "    \"\"\"Evaluate `model` on every batch of `ds` and print the overall metrics.\n",
    "\n",
    "    Per-batch loss and accuracy are weighted by batch size and averaged, so\n",
    "    the result covers the whole dataset rather than only the final batch.\n",
    "    An empty dataset yields 0.0 for both metrics instead of raising.\n",
    "\n",
    "    Args:\n",
    "        model: the model passed on to `test_step`.\n",
    "        ds: iterable of `(x, y)` batches.\n",
    "\n",
    "    Returns:\n",
    "        `(loss, accuracy)` as scalar float32 tensors.\n",
    "    \"\"\"\n",
    "    loss_sum = tf.constant(0.0)\n",
    "    accuracy_sum = tf.constant(0.0)\n",
    "    example_count = tf.constant(0.0)\n",
    "    for x, y in ds:\n",
    "        batch_loss, batch_accuracy = test_step(model, x, y)\n",
    "        # Weight each batch by its size so a short final batch is not over-counted.\n",
    "        batch_examples = tf.cast(tf.shape(x)[0], tf.float32)\n",
    "        loss_sum += batch_loss * batch_examples\n",
    "        accuracy_sum += batch_accuracy * batch_examples\n",
    "        example_count += batch_examples\n",
    "\n",
    "    # divide_no_nan returns 0 when no examples were seen.\n",
    "    loss = tf.math.divide_no_nan(loss_sum, example_count)\n",
    "    accuracy = tf.math.divide_no_nan(accuracy_sum, example_count)\n",
    "    print('test loss', loss.numpy(), '; accuracy', accuracy.numpy())\n",
    "\n",
    "    return loss, accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0 : loss 2.307023 ; accuracy 0.09375\n",
      "epoch 0 : loss 0.16172078 ; accuracy 0.96875\n",
      "epoch 1 : loss 0.15242997 ; accuracy 0.9375\n",
      "epoch 1 : loss 0.24359244 ; accuracy 0.9375\n",
      "test loss 0.057900324 ; accuracy 0.9812\n"
     ]
    }
   ],
   "source": [
    "# Build the input pipelines, train for two epochs, then evaluate once.\n",
    "train_ds, test_ds = mnist_dataset()\n",
    "for epoch in range(2):\n",
    "    loss, accuracy = train(epoch, model, optimizer, train_ds)\n",
    "# `loss` and `accuracy` are rebound here to the values returned by `test`.\n",
    "loss, accuracy = test(model, test_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
